package top.lingkang.finalsql.spring;

import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Component;
import top.lingkang.finalsql.sql.FinalSql;

import java.util.ArrayList;
import java.util.List;

/**
 * Spring AOP aspect that runs methods annotated with
 * {@code top.lingkang.finalsql.annotation.Transactional} (directly, or via an annotated type)
 * inside a transaction spanning every {@link FinalSql} bean in the application context
 * (multi-data-source support).
 *
 * <p>Transaction state is kept per thread in a {@link ThreadLocal}. The outermost annotated call
 * on a thread begins the transaction and commits or rolls it back. Nested annotated calls on the
 * same thread join that transaction instead of starting a new one. The per-thread state is always
 * cleared when the outermost call finishes, so pooled threads start clean.
 *
 * @author lingkang
 * 2023/1/17
 **/
@Aspect
@Component
public class FinalSqlTransactionalManagement {
    /** Transaction state of the current thread; {@code null} when no annotated call is active. */
    private static final ThreadLocal<TransactionState> TRANSACTION_STATE = new ThreadLocal<>();

    @Autowired(required = false)
    private ApplicationContext context;

    /**
     * Runs the intercepted method inside a FinalSql transaction.
     *
     * <p>If no {@link FinalSql} bean exists, the method simply proceeds. If a transaction is
     * already active on this thread, the call joins it. Otherwise a new transaction is begun on
     * every {@link FinalSql} bean. It is committed when the method returns normally and rolled
     * back when the method throws or a nested annotated call failed.
     *
     * @param joinPoint the intercepted invocation
     * @return whatever the intercepted method returns
     * @throws Throwable whatever the intercepted method, or a commit/rollback, throws; rollback
     *     failures during exception handling are attached as suppressed exceptions
     */
    @Around("@within(top.lingkang.finalsql.annotation.Transactional) || @annotation(top.lingkang.finalsql.annotation.Transactional)")
    public Object finalSqlTransactionalAop(ProceedingJoinPoint joinPoint) throws Throwable {
        List<FinalSql> finalSql = getFinalSql();
        if (finalSql.isEmpty()) {
            // No FinalSql data source is registered: nothing to manage.
            return joinPoint.proceed();
        }

        TransactionState current = TRANSACTION_STATE.get();
        if (current != null) {
            return proceedInExistingTransaction(joinPoint, current);
        }
        return proceedInNewTransaction(joinPoint, finalSql);
    }

    /**
     * Runs a nested annotated call inside the transaction already active on this thread.
     * This method never commits or rolls back. On failure it marks the transaction
     * rollback-only, so the outermost call will not commit it.
     */
    private static Object proceedInExistingTransaction(ProceedingJoinPoint joinPoint,
                                                       TransactionState state) throws Throwable {
        try {
            return joinPoint.proceed();
        } catch (Throwable e) {
            state.rollbackOnly = true;
            throw e;
        }
    }

    /**
     * Begins a transaction on every data source, runs the call, then commits or rolls back.
     * The thread-local state is always cleared afterwards.
     */
    private static Object proceedInNewTransaction(ProceedingJoinPoint joinPoint,
                                                  List<FinalSql> finalSql) throws Throwable {
        TransactionState state = new TransactionState();
        TRANSACTION_STATE.set(state);
        // Data sources whose transaction was begun but not yet committed or rolled back;
        // only these may be rolled back if something fails.
        List<FinalSql> open = new ArrayList<>(finalSql.size());
        try {
            // Begin the transaction exactly once, here at the outermost level.
            for (FinalSql sql : finalSql) {
                sql.beginTransaction();
                open.add(sql);
            }

            Object result = joinPoint.proceed();

            if (state.rollbackOnly) {
                // A nested annotated call failed and its exception was handled by the caller:
                // the shared transaction must not be committed.
                Throwable failure = rollbackAll(open);
                if (failure != null) {
                    throw failure;
                }
                return result;
            }

            // Commit the transactions; remove each one from `open` once committed so a later
            // commit failure only rolls back the still-uncommitted ones.
            while (!open.isEmpty()) {
                open.get(0).commitTransaction();
                open.remove(0);
            }
            return result;
        } catch (Throwable e) {
            // Roll back; catching Throwable so Errors do not leave the transaction open.
            Throwable failure = rollbackAll(open);
            if (failure != null && failure != e) {
                e.addSuppressed(failure);
            }
            throw e;
        } finally {
            // Always clear the per-thread state; otherwise a pooled thread would treat its next
            // annotated call as nested and never begin or commit a transaction.
            TRANSACTION_STATE.remove();
        }
    }

    /**
     * Rolls back every transaction in {@code open}, continuing past individual failures, and
     * then clears the list.
     *
     * @param open data sources with an uncommitted transaction
     * @return the first rollback failure (later ones attached as suppressed), or {@code null}
     */
    private static Throwable rollbackAll(List<FinalSql> open) {
        Throwable failure = null;
        for (FinalSql sql : open) {
            try {
                sql.rollbackTransaction();
            } catch (Throwable t) {
                if (failure == null) {
                    failure = t;
                } else if (t != failure) {
                    failure.addSuppressed(t);
                }
            }
        }
        open.clear();
        return failure;
    }

    /**
     * Collects every {@link FinalSql} bean from the application context (multi-data-source case).
     *
     * @return the beans in the order reported by the context, or an empty list when no
     *     {@link ApplicationContext} was injected
     */
    private List<FinalSql> getFinalSql() {
        if (context == null) {
            return new ArrayList<>();
        }
        String[] names = context.getBeanNamesForType(FinalSql.class);
        List<FinalSql> list = new ArrayList<>(names.length);
        for (String name : names) {
            list.add(context.getBean(name, FinalSql.class));
        }
        return list;
    }

    /** Mutable record of the transaction active on one thread. */
    private static final class TransactionState {
        /** Set when a nested annotated call failed; the outermost call then rolls back. */
        private boolean rollbackOnly;
    }
}
